import numpy as np
import pytest
from scipy.stats import norm


def test_quantile_function():
    """Check ``QuantileFunction`` on a small sample and on large random samples.

    Covers three cases:
    - exact values on a tiny hand-written sample (scalar and array input),
    - the median and endpoints of a large uniform sample,
    - the mean absolute deviation from ``norm.ppf`` for a large normal sample.

    Random samples come from a fixed-seed generator, so the statistical
    tolerances are checked against the same data on every run.
    """
    from soe.utils import QuantileFunction

    # Local seeded generator: keeps the stochastic checks reproducible without
    # depending on (or mutating) NumPy's global random state.
    rng = np.random.default_rng(0)
    n_samples = 1_000_000

    # small sample: quantiles at multiples of 1/4 are expected to hit the
    # sorted sample values exactly
    samples = np.array([1, 2, 3, 4, 5])
    quantile_function = QuantileFunction(samples)

    assert quantile_function(0.0) == 1
    assert quantile_function(0.25) == 2
    assert quantile_function(0.5) == 3
    assert quantile_function(1.0) == 5
    assert np.allclose(quantile_function([0.0, 0.25, 0.5, 0.75, 1.0]), samples)

    # uniform distribution on [0, 1): median near 0.5, endpoints are the
    # sample extremes
    samples = rng.random(n_samples)
    quantile_function = QuantileFunction(samples)

    assert np.abs(quantile_function(0.5) - 0.5) < 1e-2
    assert quantile_function(0.0) == np.min(samples)
    assert quantile_function(1.0) == np.max(samples)

    # standard normal distribution: compare against the analytic inverse CDF
    # away from the tails (p in [0.01, 0.99])
    samples = rng.standard_normal(n_samples)
    quantile_function = QuantileFunction(samples)

    p = np.linspace(0.01, 0.99, 99)
    assert np.abs(norm.ppf(p) - quantile_function(p)).mean() < 1e-2


def test_second_quantile():
    """Check ``SecondQuantileFunction`` at p = 0 and against a numerical reference.

    The reference is a cumulative Riemann sum of ``norm.ppf`` on a grid with
    step ``sec_quant.dp``. This presumably approximates the integral of the
    standard-normal quantile function from 0 to p (TODO: confirm against the
    ``SecondQuantileFunction`` definition in ``soe.utils``).

    The normal sample comes from a fixed-seed generator, so the tolerance is
    checked against the same data on every run.
    """
    from soe.utils import SecondQuantileFunction

    # Local seeded generator for reproducible stochastic checks.
    rng = np.random.default_rng(0)

    # Small sample
    samples = np.array([1, 2, 3, 4])
    sec_quant = SecondQuantileFunction(samples)

    assert sec_quant(0.0) == 0.0

    # Normal distribution
    samples = rng.standard_normal(1_000_000)
    sec_quant = SecondQuantileFunction(samples)

    assert sec_quant(0.0) == 0.0

    # Drop the endpoints p = 0 and p = 1, where norm.ppf is -inf / +inf.
    dx = sec_quant.dp
    x = np.linspace(0.0, 1.0, int(1.0 / dx))[1:-1]
    sec_quant_true = np.cumsum(norm.ppf(x) * dx)
    assert np.abs(sec_quant_true - sec_quant(x)).max() < 1e-2


def test_ecdf():
    """Check ``ECDF`` on a small sample and against the standard normal CDF.

    For the small sample, the ECDF is expected to be 0 below the minimum and
    k/n at the k-th sorted sample. For a large normal sample, the mean
    absolute deviation from ``norm.cdf`` on [-5, 5] must be small.

    The normal sample comes from a fixed-seed generator, so the tolerance is
    checked against the same data on every run.
    """
    from soe.utils import ECDF

    # Local seeded generator for reproducible stochastic checks.
    rng = np.random.default_rng(0)

    # small sample
    samples = np.array([1, 2, 3, 4, 5])
    cdf = ECDF(samples)

    assert cdf(0.0) == 0.0
    assert np.allclose(cdf(samples), np.arange(1, 6) / 5)

    # normal distribution
    samples = rng.standard_normal(1_000_000)
    cdf = ECDF(samples)

    x = np.linspace(-5, 5, 100)
    assert np.mean(np.abs(norm.cdf(x) - cdf(x))) < 1e-2


# Run the test with multiple values of x0 and x1
@pytest.mark.parametrize("x0, x1", [(0.0, 1.0), (2.0, 4.0)])
def test_num_integrate(x0, x1):
    """Compare numerical integration of f(t) = t with its antiderivative t**2 / 2.

    Checks the definite integral over [x0, x1] returned by ``num_integrate``.
    Also checks the callable returned by ``num_integrate_func``, which is
    expected to give the integral from x0 up to its argument, at 10 evenly
    spaced points in [x0, x1].
    """
    from soe.utils import num_integrate, num_integrate_func

    def integrand(t):
        return t

    def antiderivative(t):
        return t**2 / 2

    step = 0.001
    tol = 1e-2

    # Definite integral: $\int_{x_0}^{x_1} t \, dt = (x_1^2 - x_0^2) / 2$
    expected_total = antiderivative(x1) - antiderivative(x0)
    assert np.abs(num_integrate(integrand, x0, x1, step) - expected_total) < tol

    # Indefinite integral anchored at x0, sampled across [x0, x1]
    running_integral = num_integrate_func(integrand, x0, x1, step)
    for point in np.linspace(x0, x1, 10):
        expected = antiderivative(point) - antiderivative(x0)
        assert np.abs(running_integral(point) - expected) < tol
